# hypersense/strategy/greedy_uniform_alloc.py
from typing import List, Dict, Any, Callable
import numpy as np
from .greedy_important_first import GreedyImportantFirstStrategy

class GreedyUniformAllocStrategy(GreedyImportantFirstStrategy):
    """
    Same as GIF but allocate step_trials uniformly across groups.

    Groups are still formed by importance ranking, and importance is still
    updated each round. Each round:

    1. Parameter importance is recomputed from ``self.history``. Every
       parameter gets weight 1.0 until ``min_trials_for_importance`` trials
       exist.
    2. Parameters are sorted by descending importance and chunked into
       groups of ``top_k``.
    3. ``step_trials`` is split evenly across the groups. The integer
       remainder goes to the first (most important) group.
    4. Each group is optimized in turn, with every other parameter pinned to
       the current incumbent configuration.
    5. If no group improved the incumbent, a full-space search is run. It
       draws from a separate quota of ``max_full_trials``, which is spread
       over an estimate of the rounds still remaining.

    NOTE(review): the attributes read here (``history``, ``search_space``,
    ``step_trials``, ``top_k``, ``optimizer_builder``, ``importance_evaluator``,
    ``logs``, ...) are presumably initialized by ``GreedyImportantFirstStrategy``;
    confirm there.
    """

    def run(self) -> List[Dict[str, Any]]:
        """
        Run greedy group-wise optimization until ``max_total_trials`` is reached.

        New trials are appended to ``self.history``. One summary entry per
        successfully optimized group is appended to ``self.logs``. On return,
        ``self.current_best_config`` / ``self.current_best_score`` hold the
        best trial seen, including the seed trials already in the history.

        The loop also stops early in two cases: grouping yields no groups, or
        a whole round produces no new trials. In the latter case every
        optimizer failed or returned nothing, and retrying would loop forever.

        Returns:
            ``self.history`` (the same list, extended in place).

        Raises:
            ValueError: if ``self.history`` is empty. The incumbent is seeded
                from the best existing trial, so there must be one.
        """
        if not self.history:
            raise ValueError(
                "GreedyUniformAllocStrategy.run() needs at least one trial in "
                "history to seed the incumbent configuration."
            )

        total_trials_used = len(self.history)
        param_names = list(self.search_space.keys())

        best_trial = max(self.history, key=lambda t: t["score"])
        current_best_config = dict(best_trial["config"])
        current_best_score = best_trial["score"]

        max_full_trials = self.max_full_trials
        full_trials_used = 0
        round_idx = 0

        while total_trials_used < self.max_total_trials:
            trials_before_round = total_trials_used
            configs = [t["config"] for t in self.history]
            scores = [t["score"] for t in self.history]

            # 1) Importance: uniform weights until enough trials exist to estimate it.
            if len(self.history) < self.min_trials_for_importance:
                importance = {k: 1.0 for k in self.search_space}
            else:
                importance = self.importance_evaluator(configs, scores)

            # 2) Sort by descending importance (sorted() is stable, so ties keep
            #    their original order) and chunk into groups of top_k.
            sorted_params = sorted(importance.items(), key=lambda x: -x[1])
            sorted_param_names = [name for name, _ in sorted_params]
            groups = [
                sorted_param_names[i:i + self.top_k]
                for i in range(0, len(sorted_param_names), self.top_k)
            ]
            if not groups:
                break

            # 3) Uniform allocation instead of importance-proportional allocation.
            group_trials = self._uniform_budgets(self.step_trials, len(groups))

            # 4) Optimize each group with all other parameters pinned to the incumbent.
            round_has_improvement = False
            remaining_trials = self.max_total_trials - total_trials_used

            for group, budget in zip(groups, group_trials):
                allowed = min(budget, remaining_trials)
                if allowed <= 0:
                    break

                subspace = {k: self.search_space[k] for k in group}
                fixed_config = {
                    k: current_best_config[k]
                    for k in self.search_space
                    if k not in group
                }
                new_trials = self._uniform_optimize(
                    subspace,
                    fixed_config,
                    allowed,
                    current_best_config,
                    {"round": round_idx, "group": group},
                    f"[GIF-Uniform] Optimizer failed on group {group}",
                )
                if not new_trials:
                    # Optimizer failed or produced nothing; try the next group.
                    continue

                self.history.extend(new_trials)
                total_trials_used += len(new_trials)
                remaining_trials -= len(new_trials)

                # Update the incumbent and log BEFORE the budget check. Otherwise
                # the batch that exhausts the budget would never be considered.
                group_best = max(new_trials, key=lambda t: t["score"])
                if group_best["score"] > current_best_score:
                    current_best_score = group_best["score"]
                    current_best_config = dict(group_best["config"])
                    round_has_improvement = True

                self.logs.append({
                    "round": round_idx,
                    "group": group,
                    "trials": len(new_trials),
                    "best_score": group_best["score"],
                    "importance_snapshot": dict(importance),
                    "uniform_alloc": True,
                })

                if remaining_trials <= 0:
                    break

            # 5) Full-space fallback, only in rounds where no group improved.
            full_quota_left = max_full_trials - full_trials_used
            remaining_trials = self.max_total_trials - total_trials_used
            if remaining_trials > 0:
                # Spread the leftover full-space quota over an estimate of the
                # rounds still to run. Guard step_trials == 0 against division by zero.
                step = self.step_trials if self.step_trials > 0 else 1
                rounds_left = remaining_trials // step + 1
                full_budget = min(int(full_quota_left / rounds_left), remaining_trials)
            else:
                full_budget = 0

            if not round_has_improvement and full_quota_left > 0 and full_budget > 0:
                new_trials = self._uniform_optimize(
                    {k: self.search_space[k] for k in param_names},
                    {},
                    full_budget,
                    current_best_config,
                    {
                        "round": round_idx,
                        "group": param_names,
                        "uniform_alloc": True,
                        "full_group": True,
                    },
                    "[GIF-Uniform] Full optimizer failed",
                )
                if new_trials:
                    self.history.extend(new_trials)
                    total_trials_used += len(new_trials)
                    full_trials_used += len(new_trials)

                    full_best = max(new_trials, key=lambda t: t["score"])
                    if full_best["score"] > current_best_score:
                        # Store the score (not the config) so later comparisons stay numeric.
                        current_best_score = full_best["score"]
                        current_best_config = dict(full_best["config"])

            if total_trials_used == trials_before_round:
                # No optimizer produced a trial this round. The next round would
                # see an identical history and repeat forever, so stop here.
                break

            round_idx += 1

        self.current_best_config = current_best_config
        self.current_best_score = current_best_score
        return self.history

    @staticmethod
    def _uniform_budgets(step_trials: int, n_groups: int) -> List[int]:
        """Split ``step_trials`` evenly over ``n_groups`` (> 0); the remainder goes to group 0."""
        base, remainder = divmod(step_trials, n_groups)
        budgets = [base] * n_groups
        # Groups are ordered by descending importance, so group 0 is the most important one.
        budgets[0] += remainder
        return budgets

    def _uniform_optimize(
        self,
        subspace: Dict[str, Any],
        fixed_config: Dict[str, Any],
        budget: int,
        base_config: Dict[str, Any],
        record_fields: Dict[str, Any],
        failure_msg: str,
    ) -> List[Dict[str, Any]]:
        """
        Build an optimizer for ``subspace``, run it, and turn its results into trial records.

        Each item of the iterable returned by ``optimize()`` is unpacked as
        ``(config, result, elapsed_time)`` and becomes the record
        ``{"config": ..., "score": result, "elapsed_time": ..., **record_fields}``.
        Here ``config`` is a copy of ``base_config`` overlaid with the
        optimizer's config. The records are NOT added to ``self.history``;
        the caller does that.

        If ``optimize()`` raises, ``failure_msg`` and the error are printed
        and an empty list is returned. This is best-effort: one failing group
        should not abort the whole run. Exceptions from ``optimizer_builder``
        itself, or raised while iterating the results, propagate.
        """
        optimizer = self.optimizer_builder(subspace, self.history, fixed_config, budget)
        try:
            optimizer_results = optimizer.optimize()
        except Exception as e:
            print(f"{failure_msg}: {e}")
            return []

        trials = []
        for config, result, elapsed_time in optimizer_results:
            full_config = dict(base_config)
            full_config.update(config)
            trials.append({
                "config": full_config,
                "score": result,
                "elapsed_time": elapsed_time,
                **record_fields,
            })
        return trials
